{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basics of lattice models\n",
    "In this notebook, we'll explain a lattice model, an interpolated lookup table.\n",
    "In addition, we'll show how monotonicity and smooth regularizers can change the model."
   ]
  },
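  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the phrase \"interpolated lookup table\" concrete, here is a minimal one-dimensional sketch in plain NumPy. The keypoints and lookup values are made up for illustration, and `lattice_1d` is a hypothetical helper, not part of TensorFlow Lattice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical 1D lookup table: one stored value per keypoint.\n",
    "keypoints = np.array([0.0, 1.0, 2.0])\n",
    "lookup_values = np.array([0.0, 1.0, 0.5])\n",
    "\n",
    "def lattice_1d(x):\n",
    "    # Linearly interpolate between the two nearest lookup values.\n",
    "    return np.interp(x, keypoints, lookup_values)\n",
    "\n",
    "# At a keypoint the output is the stored value; in between it is a blend.\n",
    "print(lattice_1d(1.0))\n",
    "print(lattice_1d(1.5))"
   ]
  },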
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we need to import libraries we're going to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_lattice as tfl\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "from matplotlib.ticker import LinearLocator, FormatStrFormatter\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lattice model visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let us define helper functions for visualizing the surface of 2d lattice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypercube (multilinear) interpolation in a 2 x 2 lattice.\n",
    "# params[0] == lookup value at (0, 0)\n",
    "# params[1] == lookup value at (0, 1)\n",
    "# params[2] == lookup value at (1, 0)\n",
    "# params[3] == lookup value at (1, 1)\n",
    "def twod(x1, x2, params):\n",
    "    y = ((1 - x1) * (1 - x2) * params[0]\n",
    "         + (1 - x1) * x2 * params[1]\n",
    "         + x1 * (1 - x2) * params[2]\n",
    "         + x1 * x2 * params[3])\n",
    "    return y\n",
    "\n",
    "# This function will generate 3d plot for lattice function values.\n",
    "# params uniquely characterizes the lattice lookup values.\n",
    "def lattice_surface(params):\n",
    "    print('Lattice params:')\n",
    "    print(params)\n",
    "    \n",
    "    %matplotlib inline\n",
    "    fig = plt.figure()\n",
    "    ax = fig.gca(projection='3d')\n",
    "\n",
    "    # Make data.\n",
    "    n = 50\n",
    "    xv, yv = np.meshgrid(np.linspace(0.0, 1.0, num=n),\n",
    "                         np.linspace(0.0, 1.0, num=n))\n",
    "    zv = np.zeros([n, n])\n",
    "    for k1 in range(n):\n",
    "        for k2 in range(n):\n",
    "            zv[k1, k2] = twod(xv[k1, k2], yv[k1, k2], params)\n",
    "\n",
    "    # Plot the surface.\n",
    "    surf = ax.plot_surface(xv, yv, zv, cmap=cm.coolwarm)\n",
    "    # Customize the z axis.\n",
    "    ax.set_zlim(0.0, 1.0)\n",
    "    ax.zaxis.set_major_locator(LinearLocator(10))\n",
    "    ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))\n",
    "\n",
    "    # Add a color bar which maps values to colors.\n",
    "    fig.colorbar(surf, shrink=0.5, aspect=5,)"
   ]
  },
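  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the interpolation formula, the sketch below evaluates it at the four corners and at the center of the unit square. `bilinear` restates the formula from `twod` so that this cell runs on its own, and the parameter values are chosen only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bilinear(x1, x2, params):\n",
    "    # Same multilinear interpolation as twod, restated for a standalone check.\n",
    "    return ((1 - x1) * (1 - x2) * params[0]\n",
    "            + (1 - x1) * x2 * params[1]\n",
    "            + x1 * (1 - x2) * params[2]\n",
    "            + x1 * x2 * params[3])\n",
    "\n",
    "example_params = [0.0, 1.0, 1.0, 0.0]\n",
    "corners = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]\n",
    "\n",
    "# At each corner, the interpolation returns the corresponding lookup value.\n",
    "corner_values = [bilinear(x1, x2, example_params) for x1, x2 in corners]\n",
    "# At the center, all four lookup values are weighted equally.\n",
    "center_value = bilinear(0.5, 0.5, example_params)\n",
    "\n",
    "print(corner_values)\n",
    "print(center_value)"
   ]
  },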
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's draw a surface of 2d lattice model.\n",
    "This model represents an \"XOR\" function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This will plot the surface plot.\n",
    "lattice_surface([0.0, 1.0, 1.0, 0.0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train XOR function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll provide a synthetic data that represents the \"XOR\" function, that is\n",
    "\n",
    "f(0, 0) = 0\n",
    "f(0, 1) = 1\n",
    "f(1, 0) = 1\n",
    "f(1, 1) = 0\n",
    "\n",
    "and check whether a lattice can __learn__ this function."
   ]
  },
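  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before turning to the TensorFlow Lattice layer, the optimization problem itself can be sketched in plain NumPy. This is an illustration under stated assumptions (zero initialization, full-batch gradient descent, and a hypothetical `interpolation_weights` helper), not a description of how `tfl.lattice_layer` is implemented."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def interpolation_weights(x1, x2):\n",
    "    # Bilinear weights on the lookup values at (0, 0), (0, 1), (1, 0), (1, 1).\n",
    "    return np.array([(1 - x1) * (1 - x2), (1 - x1) * x2,\n",
    "                     x1 * (1 - x2), x1 * x2])\n",
    "\n",
    "xor_inputs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]\n",
    "xor_targets = np.array([0.0, 1.0, 1.0, 0.0])\n",
    "\n",
    "# One row of interpolation weights per training example.\n",
    "weights = np.array([interpolation_weights(x1, x2) for x1, x2 in xor_inputs])\n",
    "\n",
    "params = np.zeros(4)  # assumed initialization\n",
    "learning_rate = 0.1\n",
    "for _ in range(100):\n",
    "    residual = weights.dot(params) - xor_targets\n",
    "    # Gradient of the mean squared error with respect to the lookup values.\n",
    "    gradient = 2.0 * weights.T.dot(residual) / len(xor_targets)\n",
    "    params -= learning_rate * gradient\n",
    "\n",
    "# The learned lookup values should be close to the XOR targets.\n",
    "print(params)"
   ]
  },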
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the graph.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Prepare the dataset.\n",
    "x_data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]\n",
    "y_data = [[0.0], [1.0], [1.0], [0.0]]\n",
    "\n",
    "# Define placeholders.\n",
    "x = tf.placeholder(dtype=tf.float32, shape=(None, 2))\n",
    "y_ = tf.placeholder(dtype=tf.float32, shape=(None, 1))\n",
    "\n",
    "# 2 x 2 lattice with 1 output.\n",
    "# lattice_param is [output_dim, 4] tensor.\n",
    "lattice_sizes = [2, 2]\n",
    "(y, lattice_param, _, _) = tfl.lattice_layer(\n",
    "    x, lattice_sizes=[2, 2], output_dim=1)\n",
    "\n",
    "# Sqaured loss\n",
    "loss = tf.reduce_mean(tf.square(y - y_))\n",
    "\n",
    "# Minimize!\n",
    "train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)\n",
    "\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "# Iterate 100 times\n",
    "for _ in range(100):\n",
    "    sess.run(train_op, feed_dict={x: x_data, y_: y_data})\n",
    "\n",
    "# Fetching trained lattice parameter.\n",
    "lattice_param_val = sess.run(lattice_param)\n",
    "# Draw the surface!\n",
    "lattice_surface(lattice_param_val[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train with monotonicity\n",
    "Now we'll set monotonicity in a lattice model. We'll use the same synthetic data generated by \"XOR\" function, but now we'll require full monotonicity in both directions, x1 and x2. Note that the data does not contain monotonicity, since \"XOR\" function value decreases, i.e., f(1, 0) > f(1, 1) and f(0, 1) > f(1, 1).\n",
    "So the trained model will do its best to fit the data while satisfying the monotonicity."
   ]
  },
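  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see concretely that the \"XOR\" lookup values violate monotonicity, the sketch below checks each axis of a 2 x 2 lattice. `monotonic_axes` is a hypothetical helper written for this illustration, and the second parameter vector is a made-up monotonic example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def monotonic_axes(params):\n",
    "    # grid[i, j] is the lookup value at (x1, x2) = (i, j).\n",
    "    grid = np.asarray(params, dtype=float).reshape(2, 2)\n",
    "    # Non-decreasing in x1: every value at x1 = 1 is at least its x1 = 0 neighbor.\n",
    "    mono_x1 = bool(np.all(grid[1, :] >= grid[0, :]))\n",
    "    # Non-decreasing in x2: every value at x2 = 1 is at least its x2 = 0 neighbor.\n",
    "    mono_x2 = bool(np.all(grid[:, 1] >= grid[:, 0]))\n",
    "    return mono_x1, mono_x2\n",
    "\n",
    "# XOR lookup values: not monotonic along either axis.\n",
    "print(monotonic_axes([0.0, 1.0, 1.0, 0.0]))\n",
    "# Made-up lookup values that are monotonic along both axes.\n",
    "print(monotonic_axes([0.0, 0.5, 0.5, 1.0]))"
   ]
  },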
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "x_data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]\n",
    "y_data = [[0.0], [1.0], [1.0], [0.0]]\n",
    "\n",
    "x = tf.placeholder(dtype=tf.float32, shape=(None, 2))\n",
    "y_ = tf.placeholder(dtype=tf.float32, shape=(None, 1))\n",
    "\n",
    "# 2 x 2 lattice with 1 output.\n",
    "# lattice_param is [output_dim, 4] tensor.\n",
    "lattice_sizes = [2, 2]\n",
    "(y, lattice_param, projection_op, _) = tfl.lattice_layer(\n",
    "    x, lattice_sizes=[2, 2], output_dim=1, is_monotone=True)\n",
    "\n",
    "# Sqaured loss\n",
    "loss = tf.reduce_mean(tf.square(y - y_))\n",
    "\n",
    "# Minimize!\n",
    "train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)\n",
    "\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "# Iterate 100 times\n",
    "for _ in range(100):\n",
    "    # Apply gradient.\n",
    "    sess.run(train_op, feed_dict={x: x_data, y_: y_data})\n",
    "    # Then projection.\n",
    "    sess.run(projection_op)\n",
    "\n",
    "# Fetching trained lattice parameter.\n",
    "lattice_param_val = sess.run(lattice_param)\n",
    "# Draw it!\n",
    "# You can see that the prediction does not decrease.\n",
    "lattice_surface(lattice_param_val[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train with partial monotonicity\n",
    "Now we'll set partial monotonicity. Here only one input is constrained to be monotonic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "x_data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]\n",
    "y_data = [[0.0], [1.0], [1.0], [0.0]]\n",
    "\n",
    "x = tf.placeholder(dtype=tf.float32, shape=(None, 2))\n",
    "y_ = tf.placeholder(dtype=tf.float32, shape=(None, 1))\n",
    "\n",
    "# 2 x 2 lattice with 1 output.\n",
    "# lattice_param is [output_dim, 4] tensor.\n",
    "lattice_sizes = [2, 2]\n",
    "(y, lattice_param, projection_op, _) = tfl.lattice_layer(\n",
    "    x, lattice_sizes=[2, 2], output_dim=1, is_monotone=[True, False])\n",
    "\n",
    "# Sqaured loss\n",
    "loss = tf.reduce_mean(tf.square(y - y_))\n",
    "\n",
    "# Minimize!\n",
    "train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)\n",
    "\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "# Iterate 100 times\n",
    "for _ in range(100):\n",
    "    # Apply gradient.\n",
    "    sess.run(train_op, feed_dict={x: x_data, y_: y_data})\n",
    "    # Then projection.\n",
    "    sess.run(projection_op)\n",
    "\n",
    "# Fetching trained lattice parameter.\n",
    "lattice_param_val = sess.run(lattice_param)\n",
    "# Draw it!\n",
    "# You can see that the prediction does not decrease in one direction.\n",
    "lattice_surface(lattice_param_val[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training OR function\n",
    "Now we switch to a synthetic dataset generated by \"OR\" function to illustrate other regularizers.\n",
    "f(0, 0) = 0\n",
    "f(0, 1) = 1\n",
    "f(1, 0) = 1\n",
    "f(1, 1) = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "x_data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]\n",
    "y_data = [[0.0], [1.0], [1.0], [1.0]]\n",
    "\n",
    "x = tf.placeholder(dtype=tf.float32, shape=(None, 2))\n",
    "y_ = tf.placeholder(dtype=tf.float32, shape=(None, 1))\n",
    "\n",
    "# 2 x 2 lattice with 1 output.\n",
    "# lattice_param is [output_dim, 4] tensor.\n",
    "lattice_sizes = [2, 2]\n",
    "(y, lattice_param, _, _) = tfl.lattice_layer(\n",
    "    x, lattice_sizes=[2, 2], output_dim=1)\n",
    "\n",
    "# Sqaured loss\n",
    "loss = tf.reduce_mean(tf.square(y - y_))\n",
    "\n",
    "# Minimize!\n",
    "train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)\n",
    "\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "# Iterate 100 times\n",
    "for _ in range(100):\n",
    "    # Apply gradient.\n",
    "    sess.run(train_op, feed_dict={x: x_data, y_: y_data})\n",
    "\n",
    "# Fetching trained lattice parameter.\n",
    "lattice_param_val = sess.run(lattice_param)\n",
    "# Draw it!\n",
    "lattice_surface(lattice_param_val[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Laplacian regularizer\n",
    "Laplacian regularizer puts the penalty on lookup value changes. In other words, it tries to make the slope of each face as small as possible."
   ]
  },
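  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of such a penalty for a 2 x 2 lattice, assuming it is a per-axis coefficient times the sum of squared differences between adjacent lookup values. `laplacian_penalty` is a hypothetical helper, and the library's exact scaling may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def laplacian_penalty(params, coeffs):\n",
    "    # grid[i, j] is the lookup value at (x1, x2) = (i, j).\n",
    "    grid = np.asarray(params, dtype=float).reshape(2, 2)\n",
    "    diff_x1 = grid[1, :] - grid[0, :]  # changes along the first axis\n",
    "    diff_x2 = grid[:, 1] - grid[:, 0]  # changes along the second axis\n",
    "    return coeffs[0] * np.sum(diff_x1 ** 2) + coeffs[1] * np.sum(diff_x2 ** 2)\n",
    "\n",
    "# Exact OR lookup values change along the second axis, so they are penalized.\n",
    "print(laplacian_penalty([0.0, 1.0, 1.0, 1.0], [0.0, 1.0]))\n",
    "# Lookup values that are flat along the second axis incur no penalty.\n",
    "print(laplacian_penalty([0.0, 0.0, 1.0, 1.0], [0.0, 1.0]))"
   ]
  },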
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "x_data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]\n",
    "y_data = [[0.0], [1.0], [1.0], [1.0]]\n",
    "\n",
    "x = tf.placeholder(dtype=tf.float32, shape=(None, 2))\n",
    "y_ = tf.placeholder(dtype=tf.float32, shape=(None, 1))\n",
    "\n",
    "# 2 x 2 lattice with 1 output.\n",
    "# lattice_param is [output_dim, 4] tensor.\n",
    "lattice_sizes = [2, 2]\n",
    "(y, lattice_param, _, regularization) = tfl.lattice_layer(\n",
    "    x, lattice_sizes=[2, 2], output_dim=1, l2_laplacian_reg=[0.0, 1.0])\n",
    "\n",
    "# Sqaured loss\n",
    "loss = tf.reduce_mean(tf.square(y - y_))\n",
    "loss += regularization\n",
    "\n",
    "# Minimize!\n",
    "train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)\n",
    "\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "# Iterate 100 times\n",
    "for _ in range(1000):\n",
    "    # Apply gradient.\n",
    "    sess.run(train_op, feed_dict={x: x_data, y_: y_data})\n",
    "\n",
    "# Fetching trained lattice parameter.\n",
    "lattice_param_val = sess.run(lattice_param)\n",
    "# Draw it!\n",
    "# With heavy Laplacian regularization along the second axis, the second axis's slope becomes zero.\n",
    "lattice_surface(lattice_param_val[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Torsion regularizer\n",
    "Torsion regularizer penalizes nonlinear interactions in the feature."
   ]
  },
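  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a 2 x 2 lattice, the bilinear interpolation expands to `y = p0 + (p2 - p0) * x1 + (p1 - p0) * x2 + (p0 - p1 - p2 + p3) * x1 * x2`, so the only nonlinear interaction is the coefficient of `x1 * x2`. The sketch below penalizes the square of that coefficient. `torsion_penalty` is a hypothetical helper, and the library's exact scaling may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def torsion_penalty(params, coeff=1.0):\n",
    "    # params are the lookup values at (0, 0), (0, 1), (1, 0), (1, 1).\n",
    "    p0, p1, p2, p3 = params\n",
    "    # Coefficient of the x1 * x2 term in the expanded bilinear interpolation.\n",
    "    interaction = p0 - p1 - p2 + p3\n",
    "    return coeff * interaction ** 2\n",
    "\n",
    "# Exact OR lookup values have a nonzero interaction term.\n",
    "print(torsion_penalty([0.0, 1.0, 1.0, 1.0]))\n",
    "# Made-up lookup values on a plane have no interaction term.\n",
    "print(torsion_penalty([0.0, 0.5, 0.5, 1.0]))"
   ]
  },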
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "x_data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]\n",
    "y_data = [[0.0], [1.0], [1.0], [1.0]]\n",
    "\n",
    "x = tf.placeholder(dtype=tf.float32, shape=(None, 2))\n",
    "y_ = tf.placeholder(dtype=tf.float32, shape=(None, 1))\n",
    "\n",
    "# 2 x 2 lattice with 1 output.\n",
    "# lattice_param is [output_dim, 4] tensor.\n",
    "lattice_sizes = [2, 2]\n",
    "(y, lattice_param, _, regularization) = tfl.lattice_layer(\n",
    "    x, lattice_sizes=[2, 2], output_dim=1, l2_torsion_reg=1.0)\n",
    "\n",
    "# Sqaured loss\n",
    "loss = tf.reduce_mean(tf.square(y - y_))\n",
    "loss += regularization\n",
    "\n",
    "# Minimize!\n",
    "train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)\n",
    "\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "# Iterate 1000 times\n",
    "for _ in range(1000):\n",
    "    # Apply gradient.\n",
    "    sess.run(train_op, feed_dict={x: x_data, y_: y_data})\n",
    "\n",
    "# Fetching trained lattice parameter.\n",
    "lattice_param_val = sess.run(lattice_param)\n",
    "# Draw it!\n",
    "# With heavy Torsion regularization, the model becomes a linear model.\n",
    "lattice_surface(lattice_param_val[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf-lattice",
   "language": "python",
   "name": "tf-lattice"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
